Use complex tensors in phase_vocoder #758
Conversation
- >>> # (channel, freq, time, complex=2)
- >>> complex_specgrams = torch.randn(2, freq, 300, 2)
+ >>> # (channel, freq, time)
+ >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
this is neat!
Are you planning to change all the implementations of complex numbers in torchaudio? We could add this to the release if it's ready and all the references are changed. Anything to be aware of? In particular, we are not testing autograd, and we have not made any commitments yet. Would this support autograd?
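For reference, converting between the two representations is a pair of zero-copy views (this assumes the legacy tensor is contiguous with a trailing dimension of size 2):

```python
import torch

# Legacy layout: a real tensor with a trailing dimension of 2 (re, im).
legacy = torch.randn(2, 1025, 300, 2)

# View as a complex tensor of shape (2, 1025, 300), without copying.
as_complex = torch.view_as_complex(legacy)

# View back; this round-trips to the original legacy layout.
back = torch.view_as_real(as_complex)
assert torch.equal(legacy, back)
```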
synced offline:
(force-pushed 89ba39b to 1ebd7e7)
Codecov Report

@@        Coverage Diff        @@
##          master      #758  +/-  ##
=====================================
  Coverage       ?    89.17%
=====================================
  Files          ?        32
  Lines          ?      2512
  Branches       ?         0
=====================================
  Hits           ?      2240
  Misses         ?       272
  Partials       ?         0

Continue to review full report at Codecov.
@vincentqb @mthrok The CI for this PR looks green. On further look, the functions now also work in TorchScript.
Question for you: should we add a warning in this release that the `(..., complex=2)` representation will be deprecated in favor of complex dtypes?
Things to figure out: autograd requirements for torchaudio to be able to adopt complex numbers.
How about truncating the norm of a complex number in polar coordinates? I.e., (1) convert to polar coordinates, (2) clamp the norm, (3) convert back to cartesian coordinates. This has the advantage of reducing to a clamp on [-x, x] for real numbers. I guess this wouldn't work as well for an asymmetric interval clamp. Other cases to think about?
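A minimal sketch of that idea, assuming a PyTorch version where `Tensor.abs`, `Tensor.angle`, and `torch.polar` support complex tensors (the helper name is hypothetical):

```python
import torch

def clamp_complex_magnitude(x: torch.Tensor, max_norm: float) -> torch.Tensor:
    # (1) convert to polar coordinates, (2) clamp the norm,
    # (3) convert back to cartesian coordinates.
    mag = x.abs().clamp(max=max_norm)
    return torch.polar(mag, x.angle())
```

For real inputs viewed as complex with zero imaginary part, this reduces to clamping into [-max_norm, max_norm], as noted above.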
torchaudio/functional.py (outdated)
  phase_acc = torch.cumsum(phase, -1)

  mag = alphas * norm_1 + (1 - alphas) * norm_0

  real_stretch = mag * torch.cos(phase_acc)
  imag_stretch = mag * torch.sin(phase_acc)

- complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
+ complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))
this seems pretty clunky, is there not a better way of doing it?
agreed! After this PR: pytorch/pytorch#39617, we can just rewrite this as
torch.complex_polar(mag, phase_acc)
instead of:
real_stretch = mag * torch.cos(phase_acc)
imag_stretch = mag * torch.sin(phase_acc)
complex_specgrams_stretch = torch.view_as_complex(torch.stack([real_stretch, imag_stretch], dim=-1))
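For the record, the factory that eventually shipped in PyTorch is `torch.polar(abs, angle)` rather than `torch.complex_polar` (pytorch/pytorch#39617 was the proposal), so the one-line rewrite looks like:

```python
import torch

mag = torch.rand(2, 1025, 231)
phase_acc = torch.rand(2, 1025, 231)

# One call replaces the cos/sin/stack/view_as_complex sequence above.
complex_specgrams_stretch = torch.polar(mag, phase_acc)
assert complex_specgrams_stretch.is_complex()
```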
(force-pushed 1ebd7e7 to 4509fe8)
@@ -458,68 +458,67 @@ def phase_vocoder(
      factor of ``rate``.

  Args:
-     complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
How about something like this?

complex_specgrams (Tensor): Either a real tensor of dimension `(..., freq, time, complex=2)`
or a tensor of dimension `(..., freq, time)` with complex dtype.

We were using "complex tensor" to mean `(..., complex=2)`. This is now ambiguous. What expression do you recommend to refer to a tensor of complex dtype? "Tensor with a complex dtype"?
yeah, "tensor with a complex dtype" sounds good. However, this "or" way of documenting could be problematic in cases where the function takes more than one complex tensor. Perhaps in those cases we can add a note stating that either all inputs should be real tensors or all inputs should be of complex dtype.
I think it might be nicer to add a separate example with complex-dtype tensors so that it's also clear that the returned output would also be complex (if applicable), especially since we are planning to switch to using complex-dtype tensors in the release after the upcoming one.
I do find this "or" way of discussing this a little cumbersome, and I agree this will get long if many tensors are involved. We could add a note in each, and just define the args/returns with complex dtype. We can still keep the example for clarity.

"""
We are migrating to complex-dtype tensors. For backward compatibility reasons,
this function still supports the legacy convention of ending with a dimension of 2
to represent a complex tensor.

Args:
    complex_specgrams (Tensor): A tensor of dimension `(..., freq, time)` with complex dtype.
    rate (float): Speed-up factor
    phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

Returns:
    Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate))`
    with a complex dtype.

Example

Example - Legacy
"""

Thoughts?
P.S. Good suggestion below for example naming
torchaudio/functional.py (outdated)
      rate (float): Speed-up factor
      phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

  Returns:
-     Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
+     Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate))`
based on the comment above, this would become:

Tensor: Complex Specgrams Stretch, represented either as a real tensor of dimension
`(..., freq, ceil(time/rate), complex=2)` or a tensor of dimension `(..., freq, ceil(time/rate))` with complex dtype.

thoughts?
  >>> rate = 1.3 # Speed up by 30%
  >>> phase_advance = torch.linspace(
  >>>     0, math.pi * hop_length, freq)[..., None]
  >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
  >>> x.shape # with 231 == ceil(300 / 1.3)
- torch.Size([2, 1025, 231, 2])
+ torch.Size([2, 1025, 231])
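For reference, a runnable version of the complex-dtype example, with assumed values `freq = 1025` and `hop_length = 512` to match the shapes shown:

```python
import math
import torch
import torchaudio.functional as F

freq, hop_length = 1025, 512
complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
rate = 1.3  # Speed up by 30%
phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]
x = F.phase_vocoder(complex_specgrams, rate, phase_advance)
print(x.shape)  # torch.Size([2, 1025, 231]); 231 == ceil(300 / 1.3)
```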
we might not need to change the example. we could add a second example, or a comment next to each `, 2`. other ideas?
I think we should have an example with tensors of complex dtype so that users know how to deal with complex tensors. This is an option:
(Old API) Example:
....
(New API) Example:
...
what do you think?
Good suggestion :) How about standardizing on this?

Example - New API (using tensors with complex dtype)
Example - Old API (using tensors with (..., complex=2))

For deprecations to switch between (real, 2) and complex dtype, the options are:
I'm leaning toward 1, without a warning until we are ready for deprecation. At that point, we'll just tell the user that to suppress the warning they need to convert to complex dtype. Thoughts?
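A minimal sketch of what such a warning could eventually look like (`_warn_legacy_complex` is a hypothetical helper, not existing torchaudio code):

```python
import warnings
import torch

def _warn_legacy_complex(t: torch.Tensor) -> bool:
    # Hypothetical: detect the legacy (..., complex=2) real layout and warn.
    if not t.is_complex() and t.size(-1) == 2:
        warnings.warn(
            "Passing a real tensor with a trailing dimension of 2 is "
            "deprecated; pass a tensor with a complex dtype instead, "
            "e.g. torch.view_as_complex(t).",
            DeprecationWarning,
        )
        return True
    return False
```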
For testing, you can copy the existing tests into a new test file prefixed with …
Yeah, I would also agree that having local flags for each function makes sense, since there will be JIT issues with a global flag. And once fully migrated, we should generate a warning every time a user uses the (..., complex=2) real tensor.
I do not see an advantage or necessity in putting these test suites in separate files, unless complex types are not supported in fbcode (which they are). Especially because this is not a new function, module, or test category. The test suite class is new due to the different type, which is good enough, so putting the tests in the same existing files makes more sense.
I am fine either way. We can also add a new class in the same file! @mthrok, will the tests in the newly added files not run in fbcode?
@anjali411 They do, if we add the definition for the new file. But to me, it makes sense to separate them only if fbcode cannot run complex types (and I bet it can, because PyTorch there is always the latest version). Adding the tests to an existing file would just start running them automatically.
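If the complex-dtype suite lives in the same file, it could be as simple as a sibling test class; a sketch with illustrative names (not torchaudio's actual test classes):

```python
import unittest
import torch
import torchaudio.functional as F

class TestPhaseVocoderComplex(unittest.TestCase):
    def test_output_dtype_and_shape(self):
        specgrams = torch.randn(2, 1025, 300, dtype=torch.cfloat)
        phase_advance = torch.zeros(1025, 1)
        out = F.phase_vocoder(specgrams, rate=1.3, phase_advance=phase_advance)
        self.assertTrue(out.is_complex())             # complex in, complex out
        self.assertEqual(out.shape, (2, 1025, 231))   # 231 == ceil(300 / 1.3)
```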
torchaudio/functional.py (outdated)

-     Example
+     Example - old API
From the discussion above, it is not clear which API is "old" and which one is "new". It is more direct to say "with real tensor input" and "with complex tensor input".
yeah I agree that's more clear! updated
  norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
  norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)

+ if use_complex:
As an alternative to duplicating all the logic here, you could have instead taken the real tensor, viewed it as complex, and then used the complex codepath (viewing it back as real in the end). Something to consider?
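Such a wrapper might look like the following sketch, where `_phase_vocoder_complex` stands in for the complex-only implementation:

```python
import torch

def phase_vocoder(complex_specgrams, rate, phase_advance):
    is_legacy = not complex_specgrams.is_complex()
    if is_legacy:
        # Legacy (..., complex=2) real layout: view it as complex
        # (view_as_complex needs a contiguous trailing dim of size 2).
        complex_specgrams = torch.view_as_complex(complex_specgrams.contiguous())
    # Hypothetical complex-only implementation of the algorithm.
    out = _phase_vocoder_complex(complex_specgrams, rate, phase_advance)
    # View back so legacy callers keep receiving the layout they passed in.
    return torch.view_as_real(out) if is_legacy else out
```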
I think that's a possibility; however, the goal is to be able to remove the code in the `if not use_complex` branch after a deprecation cycle and just use the code in the other branch (which has similar logic, though there are some substantial differences, e.g., the padding logic).
well, in my suggestion, you'd delete the real code immediately :) Anyway, this is NBD
oh sorry I misread your comment and thought you meant the other way round! yeah that sounds reasonable to me. cc. @mthrok thoughts?
On more thought, the current lack of autograd support and testing for complex might introduce BC-breaking changes.
> oh sorry I misread your comment and thought you meant the other way round! yeah that sounds reasonable to me. cc. @mthrok thoughts?

> On more thought, the current lack of autograd support and testing for complex might introduce BC-breaking changes.

I am in favor of not duplicating the logic; however, if that introduces BC breakage for real-valued tensor input, then I think we can wait until autograd support arrives.
If it's R2R (real in, real out) with complex insides, the choice of JAX/TF convention doesn't matter; you'll always get the same gradients in the end.
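A tiny demo of that point: a real-to-real computation that routes through complex intermediates produces ordinary real gradients under either convention, because both input and output are real:

```python
import torch

x = torch.randn(4, 2, requires_grad=True)
z = torch.view_as_complex(x)    # complex inside
y = (z.abs() ** 2).sum()        # real scalar out
y.backward()
print(torch.allclose(x.grad, 2 * x))  # True: a plain real gradient
```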
(force-pushed dff919c to 84d954c)
(force-pushed 84d954c to efe1644)
@anjali411 -- This could make for an interesting new test: not checking the correctness of the result, but just whether autograd runs without errors.
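A sketch of such a smoke test, with assumed shapes; it only asserts that `backward()` runs, not that the gradients are correct:

```python
import math
import torch
import torchaudio.functional as F

def test_phase_vocoder_backward_smoke():
    specgrams = torch.randn(2, 1025, 300, dtype=torch.cfloat, requires_grad=True)
    phase_advance = torch.linspace(0, math.pi * 512, 1025)[..., None]
    out = F.phase_vocoder(specgrams, rate=1.3, phase_advance=phase_advance)
    out.abs().sum().backward()   # smoke test: should not raise
    assert specgrams.grad is not None
```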
Adding Zafar's changes from PR pytorch#758 to run flags + gitignore additions
To be replaced by #1410
1. `F.phase_vocoder` accepts Tensor with complex dtype.
   * The implementation has been updated from #758 so that both paths share the same code, by internally converting the input Tensor to complex dtype and performing all the operations on top of it.
   * Adopted `torch.polar` for simpler Tensor generation from magnitude and angle.
2. Updated tests
   * librosa compatibility test for complex dtype and pseudo complex dtype
   * Extracted the output shape check test and moved it to functional, so that it is tested on all combinations of `{CPU | CUDA} x {complex64 | complex128}`
   * TorchScript compatibility test for `F.phase_vocoder` and `T.TimeStretch`
   * Batch consistency test for `T.TimeStretch`
This reverts commit 391be73.
This PR updates `phase_vocoder` to use complex tensors. A local boolean flag `USE_COMPLEX` is added in `torch.phase_vocoder` to detect whether the input is complex. If the input is complex, tensors with complex dtypes are used in the implementation, and a complex-dtype tensor is returned.

Test Plan: adds JIT tests and new tests for complex dtype in separate files.

Documentation: reflects both the old API behavior, which used `(..., complex=2)`-dimension real tensors, and the new API behavior, which uses complex-dtype tensors. Deprecation warnings will be added at the end (before the release) once all functions support complex tensors.

In a follow-up PR, `torch.polar` should be added to construct complex tensors from abs and angle in `torch.functional.phase_vocoder`.